#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 14 23:07:33 2021

@author: matthewroberts
"""

import os
import sys
import numpy as np
import glob
import datetime as dt
import pandas as pd
from boto3.session import Session

import matplotlib as mpl
from matplotlib import path
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.gridspec as gridspec
import matplotlib.colors as colors
import matplotlib.cm as cm
# from geopy import distance

import pyproj
import wrf
from netCDF4 import Dataset
import pyart
from scipy.interpolate import griddata
from scipy.ndimage import median_filter
from pyproj import Geod
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
from cartopy.io.img_tiles import GoogleTiles
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter


import warnings
warnings.filterwarnings('ignore')

start_time = dt.datetime.now()
print('\nProgram Started: {0}'.format(start_time))
print('===================')

# AWS credentials for the boto3 S3 client that downloads the NEXRAD Level II
# files. They are read from the standard AWS environment variables so real
# secrets never have to be committed to source; the 'xxxxx' placeholders are
# kept as fallbacks so behavior is unchanged when the variables are unset.
# NOTE(review): noaa-nexrad-level2 is presumably a public bucket, in which case
# anonymous (unsigned) access may make these unnecessary -- confirm.
ACCESS_KEY = os.environ.get('AWS_ACCESS_KEY_ID', 'xxxxx')
SECRET_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', 'xxxxx')

# # fire name
# firename = 'bear_fire'
# # fuel code
# fuelstr = '_FUELx1'
# # radar site
# site = 'KBBX'
# # Cross section coords
# start_coord = (39.87, -120.9)
# end_coord = (39.55, -121.4)
# # PPI extent
# extent = [-121.5,-120.9, 39.48,39.9] #bear fire
# # workaround for perimeter issue
# start_thresh = 2 
# # sweep (elevation) to look at radar velocity
# rad_sweep = 3
# # radar grid limits
# grid_lims = ((0,10000),(0,160000),(0,160000))
# # PDF plot x/y limits
# PDFlims = [0,.2,-30,10] #ymin, ymax, xmin, xmax

# ---- Case configuration: Caldor Fire ----
# fire name (sub-directory under <mainpath>/data and <mainpath>/plots)
firename = 'caldor_fire'
# fuel-scenario code; selects the directory used to build the WRF file list
fuelstr = '_FUELx1'
# NEXRAD radar site whose Level II data are fetched from AWS
site = 'KDAX'
# Cross-section end points as (lat, lon) in decimal degrees
start_coord = (38.81, -120.3)
end_coord = (38.5, -120.68)
# PPI map extent: [lon_min, lon_max, lat_min, lat_max]
extent = [-120.71, -120.2, 38.42, 38.85]  # caldor fire
# Workaround for spurious fire perimeters early in the run: for WRF files whose
# index in `filelist` is below this value, the level-set perimeter is blanked.
start_thresh = 3
# sweep index (elevation) used for the observed radar velocity PPI
rad_sweep = 1
# Radar gridding limits ((z_min, z_max), (y_min, y_max), (x_min, x_max)) passed
# to pyart.map.grid_from_radars; offsets from the grid origin (the radar), in
# meters per the Py-ART convention.
grid_lims = ((0, 10000), (-40000, 120000), (20000, 180000))  # z,y,x
# PDF plot limits: [ymin, ymax, xmin, xmax]
PDFlims = [0, .3, -20, 20]

# Height index into z1d used for the WRF radial-wind PPI
radWindHgtIdx = 6
# Average WRF radial winds over levels radWindHgtIdx-2 .. radWindHgtIdx+2
AvgPPI = True
# Plot the PDF curve
PlotPDF = False

keyword = fuelstr + '_'  # for naming plots
mainpath = '/Users/matthewroberts/Documents/Projects/LEAPHI'
filepath = mainpath + '/data/' + firename + '/' + fuelstr  # data files location
savepath = mainpath + '/plots/' + firename + '/4panel_radar_PDF'  # where to save figure
filelist_og = sorted(glob.glob(filepath + '/wrfout_d03*'))

# Bare file names only: the same names are later opened from the '_CONTROL'
# directory and from every fuel-scenario directory.
filelist = [os.path.basename(f) for f in filelist_og]
    
# # Verify local directories
# print('\nSetting up '+keyword+' plots...')
# print('Mainpath: '+mainpath)
# print('Filepath: '+filepath)
# print('Savepath: '+savepath)
# # Check if directories exist, if not make them
# if not os.path.exists(mainpath):
#     sys.exit('Directory: '+mainpath+' does not exist!')
# if not os.path.exists(filepath):
#     sys.exit('Directory: '+filepath+' does not exist!')
# if not os.path.exists(savepath):
#     os.makedirs(savepath)

########################################
# Functions
########################################
def radialWind(uVal,vVal,bearing):
    """
    Return the component of the vector (uVal, vVal) along ``bearing``.

    Angles follow the compass convention used in this file: degrees, 0 along
    the +vVal axis (north), increasing clockwise toward +uVal (east). For a
    horizontal wind this is the wind component toward ``bearing``, e.g. the
    radial velocity seen along a radar azimuth. The same projection is also
    applied in a vertical plane, with (horizontal component, w) as the pair.

    |V| * cos(bearing - atan2(u, v)) == u*sin(bearing) + v*cos(bearing), so the
    direct dot-product form is used; it avoids computing arctan2, sqrt and cos
    over the full wind arrays.

    Parameters
    ----------
    uVal, vVal : array_like
        Vector components; uVal along the 90-degree direction, vVal along 0.
    bearing : float or array_like
        Direction(s) in degrees to project onto; broadcast against uVal/vVal.

    Returns
    -------
    Projected component, in the same units as uVal/vVal.
    """
    theta = np.deg2rad(bearing)
    return uVal*np.sin(theta) + vVal*np.cos(theta)

def planarWind(uVal,vVal,bearing):
    """
    Return the component of the vector (uVal, vVal) along ``bearing``.

    Used for the in-plane (along cross-section) wind. The math was a
    verbatim copy of ``radialWind`` (compass angles in degrees, 0 along +vVal,
    clockwise toward +uVal), so this delegates to it instead of keeping a
    second copy that could drift.

    Parameters
    ----------
    uVal, vVal : array_like
        Vector components; uVal along the 90-degree direction, vVal along 0.
    bearing : float or array_like
        Direction(s) in degrees; broadcast against uVal/vVal.

    Returns
    -------
    Projected component, in the same units as uVal/vVal.
    """
    return radialWind(uVal, vVal, bearing)

def relax_zone_remover(input, sr):
    """
    Drop the last ``sr`` entries along axis 0 and along axis 1 of ``input``.

    Used to remove the extra trailing points of the fire (subgrid) mesh; the
    callers describe these as the relaxation zones of the level-set function.
    The original removed one row and one column per iteration (2*sr full
    array copies); this removes each trailing band with a single np.delete
    call and returns the same values.

    Parameters
    ----------
    input : array_like
        Array with at least two dimensions (name kept for caller compatibility).
    sr : int
        Number of trailing points to remove from each of axes 0 and 1.
        ``sr <= 0`` returns ``input`` unchanged (the same object).

    Returns
    -------
    Array of shape (n0 - sr, n1 - sr, ...); a new array, not a view.

    Raises
    ------
    IndexError
        If ``sr`` exceeds the length of axis 0 or axis 1.
    """
    if sr <= 0:
        return input
    nrows, ncols = np.shape(input)[:2]
    if sr > nrows or sr > ncols:
        raise IndexError('cannot remove %d trailing points from an array of shape %s'
                         % (sr, np.shape(input)))
    # np.delete accepts a slice; deleting the whole trailing band at once is
    # equivalent to repeatedly deleting index -1.
    output = np.delete(input, np.s_[nrows - sr:], axis=0)
    output = np.delete(output, np.s_[ncols - sr:], axis=1)
    return output

def calculate_initial_compass_bearing(pointA, pointB):
    """
    Return the initial (forward) great-circle bearing from pointA to pointB.

    Formula:
        θ = atan2(sin(Δlong)·cos(lat2),
                  cos(lat1)·sin(lat2) − sin(lat1)·cos(lat2)·cos(Δlong))

    Parameters
    ----------
    pointA, pointB : tuple
        (latitude, longitude) in decimal degrees. Tuple subclasses such as
        namedtuples are accepted.

    Returns
    -------
    float
        Compass bearing in degrees, normalized to [0, 360).

    Raises
    ------
    TypeError
        If either point is not a tuple.
    """
    if not (isinstance(pointA, tuple) and isinstance(pointB, tuple)):
        raise TypeError("Only tuples are supported as arguments")

    lat1 = np.deg2rad(pointA[0])
    lat2 = np.deg2rad(pointB[0])
    diffLong = np.deg2rad(pointB[1] - pointA[1])

    x = np.sin(diffLong) * np.cos(lat2)
    y = (np.cos(lat1) * np.sin(lat2)
         - np.sin(lat1) * np.cos(lat2) * np.cos(diffLong))

    # np.arctan2 yields (-180°, 180°]; shift into the compass range [0°, 360°).
    initial_bearing = np.rad2deg(np.arctan2(x, y))
    return (initial_bearing + 360) % 360


def bbox2ij(lon,lat,bbox=(-160., -155., 18., 23.)):
    """Return the index ranges of grid points that fall inside a lon/lat box.

    i0, i1, j0, j1 = bbox2ij(lon, lat, bbox)

    Parameters
    ----------
    lon, lat : 2-D array_like
        Longitude/latitude of every grid point; same shape, rows are j and
        columns are i.
    bbox : sequence of 4 floats
        [lon_min, lon_max, lat_min, lat_max]; edges are inclusive, and the
        two bounds of a pair may be given in either order. The default is an
        immutable tuple (it used to be a mutable list).

    Returns
    -------
    i0, i1, j0, j1 : int
        Minimum and maximum column (i) and row (j) index of the points inside
        the box. The maxima are INCLUSIVE: slice with ``j0:j1+1, i0:i1+1`` to
        keep the last inside row/column.
        NOTE(review): callers in this file slice ``[j0:j1, i0:i1]``, which
        drops the last row/column -- confirm that is intended.

    Raises
    ------
    ValueError
        If lon/lat are not matching 2-D arrays, or no point lies in the box.

    Example
    -------
        i0, i1, j0, j1 = bbox2ij(lon_rho, lat_rho, [-71., -63., 39., 46.])
        h_subset = nc.variables['h'][j0:j1+1, i0:i1+1]
    """
    lon_a, lon_b, lat_a, lat_b = np.asarray(bbox, dtype=float)
    lon_lo, lon_hi = min(lon_a, lon_b), max(lon_a, lon_b)
    lat_lo, lat_hi = min(lat_a, lat_b), max(lat_a, lat_b)

    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    if lon.ndim != 2 or lon.shape != lat.shape:
        raise ValueError('lon and lat must be 2-D arrays of the same shape, got %s and %s'
                         % (lon.shape, lat.shape))

    # Vectorised rectangle test (NaN coordinates compare False -> outside).
    inside = (lon >= lon_lo) & (lon <= lon_hi) & (lat >= lat_lo) & (lat <= lat_hi)

    # np.nonzero yields (row, column) == (j, i) indices of the inside points.
    jj, ii = np.nonzero(inside)
    if ii.size == 0:
        raise ValueError('no grid points fall inside bbox %s' % (list(bbox),))
    return int(ii.min()), int(ii.max()), int(jj.min()), int(jj.max())

############## TESTING ##############
# filelist = ['wrfout_d03_2020-09-09_03:30:00']
# filelist = ['wrfout_d03_2021-08-17_23:15:00']

# From here on `filepath` is the fire's top-level data directory (above it was
# the single fuel-scenario directory). The '_CONTROL' run and each fuel
# scenario are sub-directories holding files with the names in `filelist`.
filepath = mainpath+'/data/'+firename
# Fuel-scenario sub-directories to process. '_FUELx1' uses a lowercase 'x' to
# match `fuelstr`; the former '_FUELX1' would not be found on a case-sensitive
# file system if the directory is named '_FUELx1'.
# NOTE(review): confirm the on-disk directory name is '_FUELx1'.
fuels = ['_FUELx1', '_FUELx4', '_FUELx8']

# Timestamp of each WRF file, appended by the main loop
time_arr = []
for wf in range(len(filelist)):
    ########################################
    # Open WRF file and get timestamp
    ########################################
    print('Opening WRF output...', end='', flush=True)
    wrf_file = Dataset(filepath+'/_CONTROL/'+filelist[wf], mode='r')
    
    timestamp = wrf.getvar(wrf_file,'Times').values
    timestamp = pd.to_datetime(timestamp)
    timestamp = pd.Timestamp(timestamp)
    time_arr.append(timestamp)
    print(timestamp)
    wrf_file.close()
    
    ########################################
    # Fetch radar based on wrf timestamp
    ########################################
    print('Fetching radar files from AWS...', end='', flush=True)
    session = Session(aws_access_key_id=ACCESS_KEY,
                      aws_secret_access_key=SECRET_KEY)
    s3 = session.client('s3')
    
    minutes = []
    url = []
    bucket = 'noaa-nexrad-level2'
    prefix = dt.datetime.strftime(timestamp,'%Y/%m/%d')+'/'+site+'/'+site+dt.datetime.strftime(timestamp,'%Y%m%d_%H')
    rfilelist = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)['Contents']
    for key in range(len(rfilelist)):
        # Filter out 'MDM' files
        if (str(rfilelist[key]['Key'][-3:]) == 'V06'):
            minutes.append(int(rfilelist[key]['Key'][-8:-6]))
            url.append(rfilelist[key]['Key'])
    
    # Find nearest minute in hour
    minutes = np.asarray(minutes)
    idx = np.abs(minutes-timestamp.minute).argmin()
    url = url[idx][:-8]+str(minutes[idx]).zfill(2)
    print(url)
    
    new_url = s3.list_objects_v2(Bucket=bucket, Prefix=url)['Contents'][0]['Key']
    s3.download_file(bucket, new_url, '/tmp/radarfile_'+site)
    
    radarfiles = glob.glob('/tmp/radarfile*')
    for f in radarfiles:
        os.rename(f,mainpath+'/'+f[5:])
    radarfiles = glob.glob(mainpath+'/radarfile*')
    
    ########################################
    # Process Radar Data
    ########################################
    
    print('Processing radar data...', end='', flush=True)
    # create radar object and get radar coords
    radar_obj = pyart.io.read(mainpath+'/radarfile_'+site)
    radar_lat = radar_obj.latitude['data'][0]
    radar_lon = radar_obj.longitude['data'][0]
    
    # create a gate filter which specifies gates to exclude from dealiasing
    gatefilter = pyart.filters.GateFilter(radar_obj)
    gatefilter.exclude_transition()
    gatefilter.exclude_invalid('velocity')
    gatefilter.exclude_invalid('reflectivity')
    gatefilter.exclude_outside('reflectivity', 0, 75)
    
    # perform dealiasing
    dealias_data = pyart.correct.dealias_region_based(
        radar_obj, gatefilter=gatefilter)
    radar_obj.add_field('corrected_velocity', dealias_data)
    
    # pick which sweep to look at (elevation)
    sweep_val = rad_sweep
    sweep = radar_obj.get_slice(sweep_val)
    latlon = radar_obj.get_gate_lat_lon_alt(sweep_val)
    
    # make arrays of coords and data
    rvel = radar_obj.fields['corrected_velocity']['data'][sweep]
    rlat = latlon[0]
    rlon = latlon[1]
    ralt = latlon[2]
    
    radartime = radar_obj.time['units'].split(' ')[-1]
    
    print('Constructing radar x-sect...', end='', flush=True)
    
    # Read file and create radar grid
    # Create interpolated grid obj from radar obj 
    radar_grid = pyart.map.grid_from_radars(radar_obj,
                                       grid_shape=(40,800,800),
                                       grid_limits=grid_lims,
                                       grid_origin=(radar_lat,radar_lon),
                                       fields=['corrected_velocity'],
                                       constant_roi = 1100.0,
                                       )
    # Save the variables
    tmpalt = radar_grid.point_altitude['data']
    tmplon = radar_grid.get_point_longitude_latitude()[0]
    tmplat = radar_grid.get_point_longitude_latitude()[1]
    vel = radar_grid.fields['corrected_velocity']['data']
    
    # create 3d arrays of coords
    lon = np.ones(vel.shape)
    lat = np.ones(vel.shape)
    for ii in range(len(vel[:,0,0])):
        lat[ii,:,:] = lat[ii,:,:]*tmplat
        lon[ii,:,:] = lon[ii,:,:]*tmplon
    
    # create 2d array of start coords
    slat = np.ones(tmplon.shape)*start_coord[0]
    slon = np.ones(tmplon.shape)*start_coord[1]
    # Find min dist between desired points
    geodesic = pyproj.Geod(ellps='WGS84')
    fwd_azimuth,back_azimuth,distance = geodesic.inv(slon, slat, # lon, lat
                                                     tmplon, tmplat)
    start_idx = np.argwhere(distance == distance.min()) #index of min points
    
    # create 2d array of end coords
    elat = np.ones(tmplon.shape)*end_coord[0]
    elon = np.ones(tmplon.shape)*end_coord[1]
    # Find min dist between desired points
    geodesic = pyproj.Geod(ellps='WGS84')
    fwd_azimuth,back_azimuth,distance = geodesic.inv(elon, elat, # lon, lat
                                                     tmplon, tmplat)
    end_idx = np.argwhere(distance == distance.min()) #index of min points
    
    start = (start_idx.squeeze()[1],start_idx.squeeze()[0])
    end = (end_idx.squeeze()[1],end_idx.squeeze()[0])
    # interpolated x-section line
    xy_line = wrf.xy(vel, start_point=start, end_point=end)
    # cross section along line
    # vert_cross.append(wrf.interp2dxy(vel, xy_line))
    vel_interp = wrf.interp2dxy(vel, xy_line)
    lon_interp = wrf.interp2dxy(lon, xy_line)
    lat_interp = wrf.interp2dxy(lat, xy_line)
    x_interp = np.ones(lat_interp.shape)
    z_interp = np.ones(lat_interp.shape)
    for ii in range(len(x_interp[:,0])):
        x_interp[ii,:] = x_interp[ii,:]*range(len(lat_interp[ii,:]))
        z_interp[ii,:] = z_interp[ii,:]*tmpalt[ii,0,0]
    
    # flip x directions
    vel_interp = np.flip(vel_interp,axis=1)
    lon_interp = np.flip(lon_interp,axis=1)
    lat_interp = np.flip(lat_interp,axis=1)
    # x_interp = np.flip(x_interp,axis=1)
    z_interp = np.flip(z_interp,axis=1)
    
    coord_pairs_rad = zip(lon_interp[0,:], lat_interp[0,:])
    # z_interp = wrf.interp2dxy(tmpalt, xy_line)
    
    # print(vel.shape)
    # print(vel_interp.shape)
    
    # print(lon.shape)
    # print(lon_interp.shape)
    
    # print(lat.shape)
    # print(lat_interp.shape)
    
    # print(tmpalt.shape)
    # print(z_interp.shape)
    
    # print(start_coord[1], start_coord[0])
    # print(end_coord[1], end_coord[0])
    # print([lon_interp[0,0].data,lat_interp[0,0].data])
    # print([lon_interp[0,-1].data,lat_interp[0,-1].data])
    
    print('Done.')
    
    ########################################
    # Process WRF variables
    ########################################
    print('Processing WRF output...')

    plot_wrfWindppi = []
    plot_wrfWindrhi = []
    
    for ff in fuels:
        
        wrf_file = Dataset(filepath+'/'+ff+'/'+filelist[wf],mode='r')
    
        # Standard lat/lon grid
        terrain = wrf.getvar(wrf_file, "HGT", timeidx=-1)
        lat = wrf.getvar(wrf_file, "XLAT", timeidx=-1)
        lon = wrf.getvar(wrf_file, "XLONG", timeidx=-1)
        
        # Heights
        phb = wrf.getvar(wrf_file, "PHB", timeidx=-1)/9.80665 #divide by gravity to get height in meters
        ph = wrf.getvar(wrf_file, "PH", timeidx=-1)/9.80665
        z = ph+phb
        
        # create height array for plotting
        z = z[1:,:,:]
        
        # Smoke
        smoke = wrf.getvar(wrf_file, "fire_smoke", timeidx=-1)
        
        # Winds
        u = wrf.getvar(wrf_file, "U", timeidx=-1)
        v = wrf.getvar(wrf_file, "V", timeidx=-1)
        w = wrf.getvar(wrf_file, "W", timeidx=-1)
        
        # Destagger appropriate dimensions
        u = wrf.destagger(u,-1)
        v = wrf.destagger(v,1)
        w = wrf.destagger(w,0)
        
        u[smoke < 1e-12] = np.nan
        v[smoke < 1e-12] = np.nan
        w[smoke < 1e-12] = np.nan
        
        # Rotates u/v to earth coords    
        mod_stand_lon = wrf.extract_global_attrs(wrf_file,'STAND_LON')['STAND_LON']
        mod_cen_lon = wrf.extract_global_attrs(wrf_file,'CEN_LON')['CEN_LON']
        mod_cen_lat = wrf.extract_global_attrs(wrf_file,'CEN_LAT')['CEN_LAT']
        
        winds = wrf.uvmet(u,v,lat,lon,mod_stand_lon,1)
        u = winds[0,:]
        v = winds[1,:]
        
        # interpolate u/v to constant heights MSL
        z1d = np.arange(0,5250,250)
        ui = wrf.vinterp(wrf_file, u,'ght_msl',z1d/1000.)
        vi = wrf.vinterp(wrf_file, v,'ght_msl',z1d/1000.)
        wi = wrf.vinterp(wrf_file, w,'ght_msl',z1d/1000.)
        
        wd = 270. - (np.arctan2(v,u) * 180./np.pi)
        ws = np.sqrt(u**2+v**2)
        
        wd = np.asarray(wd)
        ws = np.asarray(ws)
        
        
        # For plotting fire perimeter
        # Fire grid variables
        lfn1 = wrf.getvar(wrf_file, 'LFN', timeidx=-1)
        # loading lat/lons of fire mesh
        yf1 = wrf.getvar(wrf_file, 'FXLAT', timeidx=-1)
        xf1 = wrf.getvar(wrf_file, 'FXLONG', timeidx=-1)
        
        # subgrid ratio
        sr = int(wrf_file.dimensions['west_east_subgrid'].size)/int(wrf_file.dimensions['west_east_stag'].size)
        # removing the relaxation zones of the level-set function
        lfn1 = relax_zone_remover(lfn1, int(sr))
        xf1 = relax_zone_remover(xf1, int(sr))
        yf1 = relax_zone_remover(yf1, int(sr))
        # match lfn to met grid
        lfn2 = lfn1[::int(sr),::int(sr)]
        # remove weird perimeters before fire starts
        if (wf < start_thresh):
            lfn2[:,:] = np.nan
            lfn1[:,:] = np.nan
        
        print('Interp x-sections...', end="", flush=True)
            
        # how many slices to take on either side of center x-section (0)
        # 51 total, 25 +/- centerline
        x_idx = np.arange(-25,26,1)
        y_idx = np.arange(-25,26,1)[::-1]
        x_idx = np.arange(-8,9,1)
        y_idx = np.arange(-8,9,1)[::-1]
        
        _lon_cross = []
        _lat_cross = []
        _lfn_cross = []
        _u_cross = []
        _v_cross = []
        _w_cross = []
        _smk_cross = []
        _ter_cross = []
        _comp_wind = []
        for ll in range(len(x_idx)):
            # print(ll)
            
            # Central cross section BEAR
            # Take absolute value because sometimes outputs negative indices???
            # negative index == out of bounds point
            start_pt = wrf.ll_to_xy(wrf_file, start_coord[0], start_coord[1])
            end_pt = wrf.ll_to_xy(wrf_file, end_coord[0], end_coord[1])
            
            # # Central cross section CALDOR
            # start_pt = wrf.ll_to_xy(wrf_file, 38.83, -120.3)
            # end_pt = wrf.ll_to_xy(wrf_file, 38.5, -120.68)
            
            # create list of start/end points to loop through
            startLat = lat[start_pt.data[1]+x_idx[ll],start_pt.data[0]+y_idx[ll]]
            startLon = lon[start_pt.data[1]+x_idx[ll],start_pt.data[0]+y_idx[ll]]
            # print(startLon.data,startLat.data)
        
            endLat = lat[end_pt.data[1]+x_idx[ll],end_pt.data[0]+y_idx[ll]]
            endLon = lon[end_pt.data[1]+x_idx[ll],end_pt.data[0]+y_idx[ll]]
            # print(endLon.data,endLat.data)
            
            # Define the cross section start and end points (NE to SW)
            start_point = wrf.CoordPair(lat=startLat, lon=startLon)
            end_point = wrf.CoordPair(lat=endLat, lon=endLon)
        
            # what height to interpolate x-sections to
            interp_height = np.arange(0,16030,30) # m
        
            # Fire vars
            lfn_cross = wrf.interpline(lfn2, wrfin = wrf_file,
                                       start_point = start_point,
                                       end_point = end_point,
                                       latlon = True, meta = False)
            # Standard vars
            ter_cross = wrf.interpline(terrain, wrfin = wrf_file,
                                       start_point = start_point,
                                       end_point = end_point,
                                       latlon = True, meta = False)
            lon_cross = wrf.interpline(lon, wrfin = wrf_file,
                                       start_point = start_point,
                                       end_point = end_point,
                                       latlon = True, meta = False)
            lat_cross = wrf.interpline(lat, wrfin = wrf_file,
                                       start_point = start_point,
                                       end_point = end_point,
                                       latlon = True, meta = False)
            u_cross = wrf.vertcross(u, z, levels=interp_height,
                                    wrfin = wrf_file,
                                    start_point = start_point,
                                    end_point = end_point,
                                    latlon = True, meta = True)
            v_cross = wrf.vertcross(v, z, levels=interp_height,
                                    wrfin = wrf_file,
                                    start_point = start_point,
                                    end_point = end_point,
                                    latlon = True, meta = True)
            w_cross = wrf.vertcross(w, z, levels=interp_height,
                                    wrfin = wrf_file,
                                    start_point = start_point,
                                    end_point = end_point,
                                    latlon = True, meta = True)
            smk_cross = wrf.vertcross(smoke, z, levels=interp_height,
                                      wrfin = wrf_file,
                                      start_point = start_point,
                                      end_point = end_point,
                                      latlon = True, meta = True)
        
            # Flip everything so NE is on the right and SW is on the left
            lfn_cross = lfn_cross[::-1]
            lon_cross = lon_cross[::-1]
            lat_cross = lat_cross[::-1]
            coord_pairs = zip(lon_cross, lat_cross) # x labels for plotting
            x_array = np.arange(len(ter_cross)) # x coordinates for plotting
            ter_cross = ter_cross[::-1]
            u_cross = u_cross[:,::-1] 
            v_cross = v_cross[:,::-1] 
            w_cross = w_cross[:,::-1]
            smk_cross = smk_cross[:,::-1]
        
            # Calculate bearing for in-plane winds
            geodesic = pyproj.Geod(ellps='WGS84')
            fwd_azimuth,back_azimuth,distance = geodesic.inv(start_point.lon, start_point.lat,
                                                             end_point.lon, end_point.lat)
            # Calculate in-plane wind
            component_wind = planarWind(u_cross,v_cross,back_azimuth)
            
            _lon_cross.append(lon_cross)
            _lat_cross.append(lat_cross)
            _lfn_cross.append(lfn_cross)
            _ter_cross.append(ter_cross)
            _u_cross.append(u_cross)
            _v_cross.append(v_cross)
            _w_cross.append(w_cross)
            _smk_cross.append(smk_cross)
            _comp_wind.append(component_wind)
            
            print(str(ll)+'.', end="", flush=True)
        
        # wrf cross section interpolates to nearest points so array lengths may 
        # vary by 1 or 2 points. This trims excess points off longer x-sections
        # (if they exist) so all dimensions match for calculations/stacking.    
        lens = []
        for i in _lon_cross:
            lens.append(len(i))
        lens = np.asarray(lens)
        minlength = np.nanmin(lens)
            
        # trim arrays to length of smallest array
        for i in range(len(_lon_cross)):
            _lon_cross[i] = _lon_cross[i][:minlength]
            _lat_cross[i] = _lat_cross[i][:minlength]
            _lfn_cross[i] = _lfn_cross[i][:minlength]
            _ter_cross[i] = _ter_cross[i][:minlength]
            _u_cross[i] = _u_cross[i][:,:minlength]
            _v_cross[i] = _v_cross[i][:,:minlength]
            _w_cross[i] = _w_cross[i][:,:minlength]
            _smk_cross[i] = _smk_cross[i][:,:minlength]
            _comp_wind[i] = _comp_wind[i][:,:minlength]
        
        # Turn lists of arrays into multi-dim arrays
        _lon_cross = np.stack(_lon_cross, axis=0)
        _lat_cross = np.stack(_lat_cross, axis=0)
        _lfn_cross = np.stack(_lfn_cross, axis=0)
        _ter_cross = np.stack(_ter_cross, axis=0)
        _u_cross = np.stack(_u_cross, axis=0)
        _v_cross = np.stack(_v_cross, axis=0)
        _w_cross = np.stack(_w_cross, axis=0)
        _smk_cross = np.stack(_smk_cross, axis=0)
        _comp_wind = np.stack(_comp_wind, axis=0)
        
        # find means/max
        lon_cross = np.nanmean(_lon_cross,axis=0).squeeze()
        lat_cross = np.nanmean(_lat_cross,axis=0).squeeze()
        lfn_cross = np.nanmin(_lfn_cross,axis=0).squeeze()
        ter_cross = np.nanmean(_ter_cross,axis=0).squeeze()
        u_cross = np.nanmean(_u_cross,axis=0).squeeze()
        v_cross = np.nanmean(_v_cross,axis=0).squeeze()
        w_cross = np.nanmean(_w_cross,axis=0).squeeze()
        smk_cross = np.nanmax(_smk_cross,axis=0).squeeze()
        comp_wind = np.nanmean(_comp_wind,axis=0).squeeze()
        x_array = np.arange(len(ter_cross))
        
        # Cross section and coords of fire area
        lfn_x = x_array[lfn_cross < 0]
        lfn_y = ter_cross[lfn_cross < 0]
        # lfn_cross = lfn_cross[lfn_cross < 0]
        
        print('\nClosing WRF file...')
        wrf_file.close()
        
        #%%
        ######################
        # WRF radial wind calc
        print('Radial wind calculations...')
        
        #### Plan View ####
        # create arrays of same dimensions  for calculations
        radar_lon_arr = (lon*0)+radar_lon
        radar_lat_arr = (lat*0)+radar_lat
        
        # Calculate bearing for WRF radial winds  (plan view)
        geodesic = pyproj.Geod(ellps='WGS84')
        fwd_azimuth,back_azimuth,distance = geodesic.inv(radar_lon_arr.data, radar_lat_arr.data, 
                                                          lon.data, lat.data)
        # calculate radial wind
        radwind_UV = radialWind(ui,vi,fwd_azimuth)
        
        if (AvgPPI == True):
            radwind_UVW = []
            # avg layer with proximal ppi levels
            for ii in [-2,-1,0,1,2]:
                # 2d height/distance arrays for accounting for z coordinate in WRF radial winds
                newidx = radWindHgtIdx+ii
                rdist = distance
                rhgts = np.ones(lon.shape) * z1d[newidx]
                
                # x/y/z plan view radial winds
                angles = 90. - np.rad2deg(np.arctan2(rhgts,rdist)) #shift quadrant
                radwind_UVW.append(radialWind(radwind_UV.data[newidx,:,:], wi.data[newidx,:,:], angles))
            
            radwind_UVW = np.stack(radwind_UVW, axis=0)
            radwind_UVW = np.nanmean(radwind_UVW,axis=0).squeeze()
        
        else:
            # 2d height/distance arrays for accounting for z coordinate in WRF radial winds
            rdist = distance
            rhgts = np.ones(lon.shape) * z1d[radWindHgtIdx]
            
            # x/y/z plan view radial winds
            angles = 90. - np.rad2deg(np.arctan2(rhgts,rdist)) #shift quadrant
            radwind_UVW = radialWind(radwind_UV.data[radWindHgtIdx,:,:], wi.data[radWindHgtIdx,:,:], angles)
        
        #### Cross Section ####
        # create arrays of same dimensions  for calculations
        radar_lon_arr = (lon_cross*0)+radar_lon
        radar_lat_arr = (lat_cross*0)+radar_lat
        
        # Calculate bearing for WRF radial winds 
        geodesic = pyproj.Geod(ellps='WGS84')
        fwd_azimuth,back_azimuth,distance = geodesic.inv(radar_lon_arr, radar_lat_arr, 
                                                          lon_cross, lat_cross)
        
        # create 3d height array from 1d array
        hgts = np.ones(np.shape(u_cross))
        for i in range(len(interp_height)):
            hgts[i,:] = hgts[i,:] * interp_height[i]
        dist = np.ones(np.shape(u_cross))
        for i in range(len(interp_height)):
            dist[i,:] = dist[i,:] * distance
            
        # x/y cross section radial wind
        radwind1 = radialWind(u_cross, v_cross, fwd_azimuth)
        
        # x/y/z cross section radial winds
        angles = np.rad2deg(np.arctan2(hgts,dist))+90. #shift quadrant
        newradwind = radialWind(radwind1, w_cross, angles)
        
        plot_wrfWindppi.append(radwind_UVW)
        plot_wrfWindrhi.append(newradwind)

    #%%    
    
    ######################
    # Plotting
    print('Plotting...')
    fig = plt.figure(1,figsize=(12,15))
    mpl.rcParams['axes.linewidth'] = 2
    spec = gridspec.GridSpec(nrows=60, ncols=40, figure=fig)
    
    # Create a Terrain instance.
    # esri_terrain = ShadedReliefESRI()
    
    ax1 = fig.add_subplot(spec[:10,:10], projection=ccrs.PlateCarree()) #rows,cols
    ax2 = fig.add_subplot(spec[11:21,:10], projection=ccrs.PlateCarree()) #rows,cols
    ax3 = fig.add_subplot(spec[22:32,:10], projection=ccrs.PlateCarree()) #rows,cols
    ax4 = fig.add_subplot(spec[33:43,:10], projection=ccrs.PlateCarree()) #rows,cols

    ax5 = fig.add_subplot(spec[:10,14:38]) #rows,cols
    ax6 = fig.add_subplot(spec[11:21,14:38]) #rows,cols
    ax7 = fig.add_subplot(spec[22:32,14:38]) #rows,cols
    ax8 = fig.add_subplot(spec[33:43,14:38]) #rows,cols
    
    axis_label_size = 14
    axis_label_weight = 'bold'
    tick_label_size = 12
    perim_linewidth = 4
    cmap1 = cm.get_cmap('seismic',lut=51)
    
    # create PDF histogram/plot
    ######################
    if (PlotPDF == True):
        ax9 = fig.add_subplot(spec[46:59,:38]) #rows,cols
        
        from scipy import stats
        i0,i1,j0,j1 = bbox2ij(rlon[:,:],rlat[:,:],extent)
        # Make 1d and remove masked values
        points_rvel = rvel[j0:j1,i0:i1].flatten().compressed()
        
        # Make PDF radial velocity range on dummy plot
        figdum = plt.figure(2,figsize=(1,1))
        axdum = figdum.add_subplot(111)
        histplot = axdum.hist(points_rvel, density=True, align='mid', 
                              bins=np.arange(PDFlims[2],PDFlims[3]+1,1))
        # Only plot PDF lines and not histogram
        ax9.plot(histplot[1][:-1], histplot[0], color='black', linewidth=4, linestyle='--', label='Observed')
        # Close dummy fig
        plt.close(fig=figdum)
        
        scenarios = ['Fuelx8', 'Fuelx4', 'Fuelx1']
        colors = ['maroon', 'mediumvioletred', 'dodgerblue']
        for xx in reversed(range(len(plot_wrfWindppi))):
            i0,i1,j0,j1 = bbox2ij(lon.data[:,:],lat.data[:,:],extent)
            points_wrfVel = plot_wrfWindppi[xx][j0:j1,i0:i1].flatten()
            # Convert NaNs to zero for the time with no fire
            if (np.isnan(np.nanmax(points_wrfVel)) == True):
                points_wrfVel[np.isnan(points_wrfVel) == True] = 0
            # Make dummy figure for histogram to fetch PDF data
            figdum = plt.figure(2,figsize=(1,1))
            axdum = figdum.add_subplot(111)
            histplot = axdum.hist(points_wrfVel, density=True, align='mid',
                                  bins=np.arange(PDFlims[2],PDFlims[3]+1,1))
            # PLot PDF data from histogram
            ax9.plot(histplot[1][:-1], histplot[0], color=colors[xx], linewidth=4, label=scenarios[xx])
            ax9.legend()
            # close dummy fig
            plt.close(fig=figdum)
        
        # PLot limits and styling
        ax9.set_ylim(PDFlims[0],PDFlims[1])
        ax9.set_xlim(PDFlims[2],PDFlims[3])
        
        # Set border and ticks around plot
        for axis in ['top', 'bottom', 'left', 'right']:
            ax9.spines[axis].set_linewidth(2)  # change width
            ax9.spines[axis].set_zorder(99)
        ax9.xaxis.set_tick_params(width=2,length=5,direction="in", labelsize=tick_label_size, zorder=99, bottom=True, top=True)
        ax9.yaxis.set_tick_params(width=2,length=5,direction="in", labelsize=tick_label_size, zorder=99, left=True, right=True)
        ax9.grid('on')
        
        ax9.set_ylabel("Probability", fontsize=axis_label_size, fontweight=axis_label_weight)
        ax9.set_xlabel("Radial Velocity", fontsize=axis_label_size, fontweight=axis_label_weight)
     
    # If not plotting PDF make a colorbar axis in its place
    if (PlotPDF == False):
        ax9 = fig.add_subplot(spec[48,5:33])   
        
    
    
    # create a radar ppi plot 
    ######################
    ax1.set_extent(extent, crs=ccrs.PlateCarree())
    p1 = ax1.pcolormesh(rlon,rlat,rvel, vmin=-25, vmax=25, cmap=cmap1)#,transform=ccrs.PlateCarree())
    perim1 = ax1.contour(xf1, yf1, lfn1, 0, linewidths=perim_linewidth, colors='k')
    perim = ax1.contour(xf1, yf1, lfn1, 0, linewidths=perim_linewidth-2, colors='orangered')
    xsec1 = ax1.plot([lon_interp[0,0],lon_interp[0,-1]], [lat_interp[0,0],lat_interp[0,-1]], linewidth=5, color='white', alpha=.5, zorder=11,transform=ccrs.PlateCarree())
    xsec = ax1.plot([lon_interp[0,0],lon_interp[0,-1]], [lat_interp[0,0],lat_interp[0,-1]], linewidth=3, linestyle='--', color='k', zorder=11,transform=ccrs.PlateCarree())
    
    ax1.set_xticks(np.linspace(extent[0]-.03,extent[1]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set longitude indicators
    ax1.set_yticks(np.linspace(extent[2]-.03,extent[3]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set latitude indicators
    # lon_formatter = LongitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}',dateline_direction_label=True) # format lons
    lat_formatter = LatitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}') # format lats
    # ax1.xaxis.set_major_formatter(lon_formatter) # set lons
    ax1.yaxis.set_major_formatter(lat_formatter) # set lats
    ax1.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=0, labelcolor='white', zorder=12, bottom=True, top=True)
    ax1.yaxis.set_tick_params(width=2,length=5,direction='in',labelsize=tick_label_size, zorder=12, left=True, right=True)
    
    # Load county boundary shapefile from local disk (no download happens here)
    reader = shpreader.Reader('/Users/matthewroberts/Documents/Data/shapefiles/countyl010g_shp/countyl010g.shp')
    counties = list(reader.geometries())
    COUNTIES = cfeature.ShapelyFeature(counties, ccrs.PlateCarree())
    ax1.add_feature(COUNTIES, facecolor='none', edgecolor='k', linewidth=1, alpha=.4, zorder=10)
    
    ax1.outline_patch.set_linewidth(2)
    ax1.outline_patch.set_zorder(99)

    ax1.set_ylabel("Latitude", fontsize=axis_label_size, fontweight=axis_label_weight)
    
    ax1.set_title(pd.Timestamp(radartime).strftime('%Y-%m-%d %X')+'\n' +
                  site+' Velocity PPI Sweep '+str(rad_sweep), fontsize=axis_label_size, fontweight=axis_label_weight)
        
    # synthetic radar ppi plot
    ######################
    # scenarios = ['Fuelx1', 'Fuelx4', 'Fuelx8']
    axvals = [ax2, ax3, ax4]
    for xx in reversed(range(len(plot_wrfWindppi))):
    
        axVal = axvals[xx]
        axVal.set_extent(extent, crs=ccrs.PlateCarree())
        p1 = axVal.pcolormesh(lon,lat,plot_wrfWindppi[xx], vmin=-25, vmax=25, cmap=cmap1,transform=ccrs.PlateCarree())
        perim1 = axVal.contour(xf1, yf1, lfn1, 0, linewidths=perim_linewidth, colors='k')
        perim = axVal.contour(xf1, yf1, lfn1, 0, linewidths=perim_linewidth-2, colors='orangered')
        xsec1 = axVal.plot([lon_interp[0,0],lon_interp[0,-1]], [lat_interp[0,0],lat_interp[0,-1]], linewidth=5, color='white', alpha=.5, zorder=11,transform=ccrs.PlateCarree())
        xsec = axVal.plot([lon_interp[0,0],lon_interp[0,-1]], [lat_interp[0,0],lat_interp[0,-1]], linewidth=3, linestyle='--', color='k', zorder=11,transform=ccrs.PlateCarree())
    
        axVal.set_xticks(np.linspace(extent[0]-.03,extent[1]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set longitude indicators
        axVal.set_yticks(np.linspace(extent[2]-.03,extent[3]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set latitude indicators
        lon_formatter = LongitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}',dateline_direction_label=True) # format lons
        lat_formatter = LatitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}') # format lats
        axVal.xaxis.set_major_formatter(lon_formatter) # set lons
        axVal.yaxis.set_major_formatter(lat_formatter) # set lats
        if (axVal == ax4):
            axVal.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=tick_label_size, rotation=20, bottom=True, top=True)
        else:
            axVal.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=tick_label_size, labelcolor='white',rotation=20, bottom=True, top=True)
        
        axVal.yaxis.set_tick_params(width=2,length=5,direction='in',labelsize=tick_label_size, left=True, right=True)
        
        # Load county boundary shapefile from local disk (no download happens here)
        reader = shpreader.Reader('/Users/matthewroberts/Documents/Data/shapefiles/countyl010g_shp/countyl010g.shp')
        counties = list(reader.geometries())
        COUNTIES = cfeature.ShapelyFeature(counties, ccrs.PlateCarree())
        axVal.add_feature(COUNTIES, facecolor='none', edgecolor='k', linewidth=1, alpha=.4, zorder=10)
        
        axVal.outline_patch.set_linewidth(2)
        axVal.outline_patch.set_zorder(99)
        
        axVal.set_ylabel("Latitude", fontsize=axis_label_size, fontweight=axis_label_weight)
        if (axVal == ax4):
            axVal.set_xlabel("Longitude", fontsize=axis_label_size, fontweight=axis_label_weight)
        
        # axVal.set_title(timestamp.strftime('%Y-%m-%d %X')+'\n' +
        #               'Simulated '+site+' Velocity PPI '+str(int(z1d[radWindHgtIdx]))+' m MSL', fontsize=12, fontweight='bold')
    
    # radar cross section
    ######################
    #create interpolated terrain from WRF for radar data
    xcoord = np.linspace(0,len(ter_cross), len(z_interp[0,:]))
    xp = np.linspace(0,len(ter_cross), len(ter_cross))
    fp = ter_cross
    radz = np.interp(xcoord, xp, fp)
    # fp = lfn_cross
    # radlfn = np.interp(xcoord, xp, fp)
    # rad_lfnx = xcoord[radlfn < 0]
    # rad_lfny = radz[radlfn < 0]
    
    # ax5 = fig.add_subplot(222)
    ax5.pcolormesh(x_interp,z_interp,vel_interp, vmin=-25, vmax=25, cmap=cmap1, zorder=0)
    # plot terrain
    ax5.plot(x_interp[0,:],radz,color='k',linewidth=4.,zorder=2)
    ax5.fill_between(x_interp[0,:],radz,y2=0, color='sienna', zorder=1)
    # Plot fire area
    ax5.plot(lfn_x*(len(x_interp[0,:])/len(lat_cross)),lfn_y,color='k',linewidth=perim_linewidth, zorder=3)
    ax5.plot(lfn_x*(len(x_interp[0,:])/len(lat_cross)),lfn_y,color='orangered',linewidth=perim_linewidth-2, zorder=4)

    # Set the x-ticks to use latitude and longitude labels.
    tick_space = 90 #higher values = fewer ticks
    x_ticks = np.arange(len(lat_cross))
    x_ticks = x_ticks[tick_space::tick_space]*(len(x_interp[0,:])/len(lat_cross))
    ax5.set_xticks(x_ticks)
    
    ax5.set_xlim(x_interp[0,0],x_interp[0,-1])
    ax5.set_ylim(0,9000)
    
    # Set border and ticks around plot
    for axis in ['top', 'bottom', 'left', 'right']:
        ax5.spines[axis].set_linewidth(2)  # change width
        ax5.spines[axis].set_zorder(99)
    
    ax5.xaxis.set_tick_params(width=2,length=5,direction="in", labelsize=0, zorder=99, bottom=True, top=True)
    ax5.yaxis.set_tick_params(width=2,length=5,direction="in", labelsize=tick_label_size, zorder=99, left=True, right=True)
    
    ax5.set_ylabel("Height MSL [m]", fontsize=axis_label_size, fontweight=axis_label_weight)
    
    ax5.set_title(site+' Radial Wind Cross Section', fontsize=axis_label_size, fontweight=axis_label_weight)
    
    # WRF radar cross section
    ######################
    # scenarios = ['Fuelx8', 'Fuelx4', 'Fuelx1']
    axvals = [ax6, ax7, ax8]
    for xx in reversed(range(len(plot_wrfWindrhi))):
        
        axVal = axvals[xx]
        wrf_xsec = axVal.pcolormesh(x_array, interp_height, plot_wrfWindrhi[xx], vmin=-25, vmax=25, cmap=cmap1, zorder=0)
        # plot terrain
        axVal.plot(x_array,ter_cross,color='k',linewidth=4.,zorder=2)
        axVal.fill_between(x_array,ter_cross,y2=0, color='sienna', zorder=1)
        # Plot fire area
        axVal.plot(lfn_x,lfn_y,color='k',linewidth=perim_linewidth, zorder=3)
        axVal.plot(lfn_x,lfn_y,color='orangered',linewidth=perim_linewidth-2, zorder=4)
        
        # Set the x-ticks to use latitude and longitude labels.
        tick_space = 90 #higher values = fewer ticks
        x_ticks = np.arange(len(lat_cross))
        x_labels = ['{:.2f}, {:.2f}'.format(l1,l2) for l1,l2 in coord_pairs]
        axVal.set_xticks(x_ticks[tick_space::tick_space])
        # axVal.set_xticklabels(x_labels[tick_space::tick_space], rotation=20, horizontalalignment='center')
        
        # Set x/y lims
        axVal.set_xlim(x_array[0],x_array[-1])
        axVal.set_ylim(0,9000)
        
        # Set border and ticks around plot
        for axis in ['top', 'bottom', 'left', 'right']:
            axVal.spines[axis].set_linewidth(2)  # change width
            axVal.spines[axis].set_zorder(99)
    
        if (axVal == ax8):
            axVal.set_xticklabels(x_labels[tick_space::tick_space], rotation=20, horizontalalignment='center')
            axVal.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=tick_label_size, rotation=20, bottom=True, top=True)
        else:
            axVal.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=tick_label_size, labelcolor='white',rotation=20, bottom=True, top=True)
        axVal.set_yticklabels([0,2000,4000,6000,8000])
        axVal.yaxis.set_tick_params(width=2,length=5,direction="in", labelsize=tick_label_size, zorder=99, left=True, right=True)
        
        # Title and axis labels
        axVal.set_ylabel("Height MSL [m]", fontsize=axis_label_size, fontweight=axis_label_weight)
        axVal.set_xlabel("Longitude, Latitude", fontsize=axis_label_size, fontweight=axis_label_weight)
        # plt.colorbar(wrf_xsec)
        
        # ax4.set_title('Simulated '+site+' Radial Wind Cross Section', fontsize=12, fontweight='bold')
        
    cbar = fig.colorbar(wrf_xsec, orientation='horizontal', cax=ax9)
    cbar_lvls = np.arange(-25,30,5)
    cbar.set_ticks(cbar_lvls)
    cbar.set_ticklabels(cbar_lvls.astype(int))
    cbar.ax.tick_params(labelsize=tick_label_size)
    cbar.set_label(r'Radial Velocity [m $\mathregular{s^{-1}}$]', fontsize=axis_label_size, fontweight=axis_label_weight)
    
    # plt.show()
    # sys.exit()
    
    # fig.text(.5,.92,timestamp.strftime('%Y-%m-%d %X'), fontsize=14, fontweight='bold')
    
    # plt.show()
    # fig.savefig('/Users/matthewroberts/Desktop/RAD.png', bbox_inches='tight', dpi=400)
    # sys.exit()
    
    plt.savefig(savepath+'/TEST'+timestamp.strftime('%Y-%m-%d %X')+'.png',bbox_inches='tight',dpi=260)
    plt.close('all')
    print('Saved '+savepath+'/'+timestamp.strftime('%Y-%m-%d %X')+'.png')
    # sys.exit()



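#%%
# NOTE: Everything inside the triple-quoted string below is disabled legacy code
# (an earlier regridding / model-vs-radar comparison experiment). Python evaluates
# the string literal and discards it, so none of it runs. If it is ever revived,
# fix at least these points:
#   - the `rlon[rvel <= -64]==np.nan` style lines are comparisons, not assignments
#     (use `=`);
#   - `plt.show` at the end is missing its call parentheses;
#   - the unconditional `sys.exit()` halts execution before the griddata step.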
"""


############################### REGRIDDING SECTION ###############################



########################################
# Manipulate model data
########################################
# Create 3d arr of radials, lons, lats
radial = calculate_initial_compass_bearing((center_lat,center_lon), (lat,lon))

# center_x = 6371.*np.deg2rad(center_lon)*np.cos(np.mean(np.deg2rad(center_lat)))
# center_y = 6371.*np.deg2rad(center_lat)

# # Convert lat/lon to xy distance from radar tower
# wrf_x = 6371.*np.deg2rad(lon)*np.cos(np.mean(np.deg2rad(center_lat)))
# wrf_y = 6371.*np.deg2rad(lat)
# lon = wrf_x
# lat = wrf_y

wgs84_geod = Geod(ellps='WGS84')
#Get distance between pairs of lat-lon points
def Distance(lat1,lon1,lat2,lon2):
  az12,az21,dist = wgs84_geod.inv(lon1,lat1,lon2,lat2) #Yes, this order is correct
  return dist

cen_lat = lon*0
cen_lon = lon*0
cen_lat = cen_lat + center_lat
cen_lon = cen_lon + center_lon

d = Distance(lat,lon,cen_lat,cen_lon)
wrf_x = np.sin(np.deg2rad(radial))*d
wrf_y = np.cos(np.deg2rad(radial))*d

lat = wrf_y
lon = wrf_x

# print(lon[:10,0])

# proj = ccrs.LambertConformal(central_latitude = mod_cen_lat,
#                              central_longitude = mod_cen_lon)

# transform = proj.transform_points(ccrs.PlateCarree(), lon, lat)

# lon = transform[..., 0]
# lat = transform[..., 1]

# print(lon[:10,0])
# sys.exit()

# sys.exit()
radial1 = radial
lon1 = lon
lat1 = lat
for dim in range(len(z)-1):
    radial1 = np.dstack((radial1,radial))
    lon1 = np.dstack((lon1,lon))
    lat1 = np.dstack((lat1,lat))
radial = radial1.transpose((2,0,1))
lon = lon1.transpose((2,0,1))
lat = lat1.transpose((2,0,1))
# Create 3d arr of heights
_,_,z1 = np.meshgrid(lat[0,0,:],lat[0,:,0],z)
z = z1.transpose((2,0,1))
# Calc radial winds
rad_wind = ws*np.cos(np.deg2rad(radial-wd))*-1.

print('Interpolating...')
for s in range(radar_obj.nsweeps):
    if (radar_obj.get_gate_lat_lon_alt(s)[0][:,0].shape[0] == 720):
        if (s == 0):
            rlat = radar_obj.get_gate_lat_lon_alt(s)[0]
            rlon = radar_obj.get_gate_lat_lon_alt(s)[1]
            rz = radar_obj.get_gate_lat_lon_alt(s)[2]
            rvel = radar_obj.fields['velocity']['data'][radar_obj.sweep_start_ray_index['data'][s]:radar_obj.sweep_end_ray_index['data'][s]+1,:]
        else:
            rlat1 = radar_obj.get_gate_lat_lon_alt(s)[0]
            rlon1 = radar_obj.get_gate_lat_lon_alt(s)[1]
            rz1 = radar_obj.get_gate_lat_lon_alt(s)[2]
            rvel1 = radar_obj.fields['velocity']['data'][radar_obj.sweep_start_ray_index['data'][s]:radar_obj.sweep_end_ray_index['data'][s]+1,:]

            rlat = np.dstack((rlat,rlat1))
            rlon = np.dstack((rlon,rlon1))
            rz = np.dstack((rz,rz1))
            rvel = np.dstack((rvel,rvel1))

# Velocity scans, lat/lon/z
rlat = rlat.transpose((2,0,1))
rlon = rlon.transpose((2,0,1))
rz = rz.transpose((2,0,1))
rvel = rvel.transpose((2,0,1))

# # Convert lat/lon to xy distance from radar tower
# rad_x = 6371.*np.deg2rad(rlon)*np.cos(np.mean(np.deg2rad(center_lat)))
# rad_y = 6371.*np.deg2rad(rlat)
# rlon = rad_x
# rlat = rad_y

radial = calculate_initial_compass_bearing((center_lat,center_lon), (rlat,rlon))
cen_lat = rlon*0
cen_lon = rlon*0
cen_lat = cen_lat + center_lat
cen_lon = cen_lon + center_lon
d = Distance(rlat,rlon,cen_lat,cen_lon)
rad_x = np.sin(np.deg2rad(radial))*d
rad_y = np.cos(np.deg2rad(radial))*d

rlat = rad_y
rlon = rad_x

# idxs = np.where(rlat[0,:,100] > lat.min())
# rlat = rlat[:,idxs[0],:]
# rlon = rlon[:,idxs[0],:]
# rz = rz[:,idxs[0],:]

# Crop data
########################################
# crop_area = [-50000,60000,-10000,90000]
crop_area = [-5000,40000,-5000,40000]
# Only look at model data within camp fire area to speed up interp
i0,i1,j0,j1 = bbox2ij(lon[0,:,:],lat[0,:,:],crop_area)
lon = lon[:,j0:j1,i0:i1]
lat = lat[:,j0:j1,i0:i1]
z = z[:,j0:j1,i0:i1]
rad_wind = rad_wind[:,j0:j1,i0:i1]

# Only look at model data within camp fire area to speed up interp
i0,i1,j0,j1 = bbox2ij(rlon[0,:,:],rlat[0,:,:],crop_area)
rlon = rlon[:,j0:j1,i0:i1]
rlat = rlat[:,j0:j1,i0:i1]
rz = rz[:,j0:j1,i0:i1]
rvel = rvel[:,j0:j1,i0:i1]
########################################

plt.pcolormesh(rlon[3,:,:],rlat[3,:,:],rvel[3,:,:])
plt.show()

# velshp = rvel.shape
# print(rvel.shape)
# rvel = rvel[rvel > -64]
# print(rvel.shape)
# rvel = rvel[rvel < 64]
# print(rvel.shape)
#%%
# for s in range(radar_obj.nsweeps):
#     print(radar_obj.get_elevation(s)[0])

# # v1 = np.asarray([])
# for v in [1,3]:#np.arange(len(rvel[:,0,0])-1)+1:
#     if (v == 1):
#         v1 = rvel[v,:,:].flatten()
#         v1 = v1[v1 > -64]
#         print(v1.shape)
#     else:
#         v2 = rvel[v,:,:].flatten()
#         v2 = v2[v2 > -64]
#         v1 = np.vstack((v1,v2))
#     # sys.exit()

# sys.exit()

lvl = 3
azi = 130
rlon = rlon[lvl,azi:azi+140,:]
rlat = rlat[lvl,azi:azi+140,:]
rz = rz[lvl,azi:azi+140,:]
rvel = rvel[lvl,azi:azi+140,:]
print(rvel.shape)

rlon[rvel <= -64]==np.nan
rlat[rvel <= -64]==np.nan
rz[rvel <= -64]==np.nan
rvel[rvel <= -64]==np.nan
print(rvel.shape)

plt.pcolormesh(rlon,rlat,rvel)
plt.show()

sys.exit()

# velshp = rvel.shape
# rvel = rvel.flatten()
# print(rvel.shape)
# rvel = rvel[np.logical_not(np.isnan(rvel))]
# print(rvel.shape)

# sys.exit()

# Data coordinates
points = np.array( (z.flatten(), lat.flatten(), lon.flatten()) ).T
# Data values @ above coords
values = rad_wind.flatten()

# print(points.shape)
# sys.exit()

# New data values on interp grid
rad_wind = griddata( points, values, (rz,rlat,rlon), method='linear' )

end_time = dt.datetime.now()
print('===================')
print('Total elapsed time: {0}\n'.format(end_time-start_time))

# sys.exit()
#%%
########################################
# Plotting
########################################
print('Plotting...')
fig = plt.figure(figsize=(12, 5))
# plot super resolution reflectivity
ax = fig.add_subplot(122)
display = pyart.graph.RadarDisplay(radar_obj)
display.plot('velocity', 4, title='Radar Vel (elev 1)',
              vmin=-20, vmax=20, cmap='seismic', colorbar_label='', ax=ax,zorder=9)
display.set_limits(xlim=(-45, 65), ylim=(-5, 85), ax=ax)
ax.scatter(0, 0, marker='*', s=500, c='magenta', edgecolors=['k'], zorder=12)

ax2 = fig.add_subplot(121)
ax2.pcolormesh(rlon[:,:],rlat[:,:],rad_wind[:,:],vmin=-20,vmax=20,cmap='seismic')
ax2.set_title('Model Vel (elev 1)')
ax2.scatter(center_lon, center_lat, marker='*', s=500, c='magenta', edgecolors=['k'], zorder=11)
ax2.set_xlim(-45000, 65000)
ax2.set_ylim(-5000, 85000)

plt.show

"""

# Final runtime report: wall-clock time since `start_time` was captured at program start.
end_time = dt.datetime.now()
print('===================',
      'Total elapsed time: {0}\n'.format(end_time - start_time),
      sep='\n')